package it.uniroma2.exp.task;

import it.uniroma2.dtk.dt.GenericDT;
import it.uniroma2.dtk.op.convolution.ShuffledCircularConvolution;
import it.uniroma2.tk.TreeKernel;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.tree.Tree;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.HashMap;

import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint.Builder;
import edu.berkeley.compbio.jlibsvm.binary.C_SVC;
import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import edu.berkeley.compbio.jlibsvm.labelinverter.StringLabelInverter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel.AllVsAllMode;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel.OneVsAllMode;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassProblemImpl;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassificationSVM;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;

/**
 * Experiment driver that trains and evaluates two multi-class SVMs on the same
 * labelled trees: one over vectors produced by {@link GenericDT} with a
 * dot-product kernel, and one directly over the trees with {@link TreeKernel}.
 * Timings and accuracies are printed to standard output.
 */
public class QCTaskTester {
	
	/** Training file used when no path is supplied. */
	private static final String DEFAULT_TRAIN_FILE = "/home/lorenzo/esperimenti/dtk/qc/data/smallTrain";
	
	/** Testing file used when no path is supplied. */
	private static final String DEFAULT_TEST_FILE = "/home/lorenzo/esperimenti/dtk/qc/data/smallTest";
	
	/** Marker that precedes the Penn-style tree on each example line. */
	private static final String TREE_BEGIN = "|BT|";
	
	/** Marker that follows the Penn-style tree on each example line. */
	private static final String TREE_END = "|ET|";
	
	/**
	 * Entry point.
	 * 
	 * @param args optional; when at least two arguments are given, the first is
	 *             the training file and the second the testing file. Otherwise
	 *             the built-in default paths are used.
	 * @throws Exception if loading the data or running an experiment fails
	 */
	public static void main(String[] args) throws Exception {
		QCTaskTester att = new QCTaskTester();
		if (args != null && args.length >= 2)
			att.run(args[0], args[1]);
		else
			att.run();
	}
	
	/**
	 * Runs both experiments on the default training and testing files.
	 * 
	 * @throws Exception if loading the data or running an experiment fails
	 */
	public void run() throws Exception {
		run(DEFAULT_TRAIN_FILE, DEFAULT_TEST_FILE);
	}
	
	/**
	 * Loads the two data sets and runs the distributed experiment followed by
	 * the tree-kernel experiment.
	 * 
	 * @param trainFile path of the training examples file
	 * @param testFile path of the testing examples file
	 * @throws Exception if loading the data or running an experiment fails
	 */
	public void run(String trainFile, String testFile) throws Exception {
		System.out.println("Loading training and testing set...");
		HashMap<Tree, String> trainExamples = loadExamples(trainFile);
		HashMap<Tree, String> testExamples = loadExamples(testFile);
		System.out.println(trainExamples.size()+" training examples, "+testExamples.size()+" testing examples");
		runDistributed(trainExamples, testExamples);
		runOriginal(trainExamples, testExamples);
	}
	
	/**
	 * Trains a multi-class SVM directly on the trees using {@link TreeKernel}
	 * and prints the training time and the accuracy on the test set.
	 * 
	 * @param trainExamples training trees mapped to their labels
	 * @param testExamples testing trees mapped to their expected labels
	 */
	@SuppressWarnings("unchecked")
	private void runOriginal(HashMap<Tree, String> trainExamples, HashMap<Tree, String> testExamples) {
		// The problem implementation wants an integer id per example; ids are
		// assigned sequentially in key-set iteration order.
		HashMap<Tree, Integer> exampleIds = new HashMap<Tree, Integer>(trainExamples.size());
		int nextId = 0;
		for (Tree t : trainExamples.keySet()) {
			exampleIds.put(t, nextId);
			nextId++;
		}
		System.out.println("Building parameters...");
		MultiClassificationSVM<String, Tree> svm = new MultiClassificationSVM<String, Tree>(new C_SVC<String, Tree>());
		MultiClassProblemImpl<String, Tree> artificialProblem = new MultiClassProblemImpl<String, Tree>(String.class, new StringLabelInverter(), trainExamples, exampleIds, new NoopScalingModel<Tree>());
		Builder<String, Tree> svmParamBuilder = new Builder<String, Tree>();
		// lambda is a static field of TreeKernel, so this affects every instance.
		TreeKernel.lambda = 0.4;
		svmParamBuilder.kernel = new TreeKernel();
		svmParamBuilder.eps = (float) 0.001;
		svmParamBuilder.cache_size = 2048;
		svmParamBuilder.allVsAllMode = AllVsAllMode.None;
		svmParamBuilder.oneVsAllMode = OneVsAllMode.Best;
		System.out.println("Training...");
		long time = System.currentTimeMillis();
		MultiClassModel<String, Tree> model = svm.train(artificialProblem, svmParamBuilder.build());
		System.out.println((System.currentTimeMillis()-time)/1000.0+" seconds");
		System.out.println("Testing...");
		double correct = 0;
		for (Tree t : testExamples.keySet()) {
			if (model.predictLabel(t).equals(testExamples.get(t)))
				correct++;
		}
		System.out.println("Accuracy: "+(correct/testExamples.size()));
	}
	
	/**
	 * Converts every tree to a vector with {@link GenericDT}, trains a
	 * multi-class SVM on those vectors with a dot-product kernel, and prints
	 * the conversion time, the training time and the accuracy on the test set.
	 * 
	 * @param originalTrainExamples training trees mapped to their labels
	 * @param originalTestExamples testing trees mapped to their expected labels
	 * @throws Exception if building the distributed trees fails
	 */
	@SuppressWarnings("unchecked")
	private void runDistributed(HashMap<Tree, String> originalTrainExamples, HashMap<Tree, String> originalTestExamples) throws Exception {
		System.out.println("Running distributed experiment...");
		// NOTE(review): argument meanings are defined by GenericDT, which is not
		// visible here; 0.4 matches the lambda used for TreeKernel -- confirm.
		GenericDT dt = new GenericDT(0, 4096, 0.4, ShuffledCircularConvolution.class);
		// double[] keys hash by identity, so each produced vector is its own entry.
		HashMap<double[], String> trainExamples = new HashMap<double[], String>(originalTrainExamples.size());
		HashMap<double[], String> testExamples = new HashMap<double[], String>(originalTestExamples.size());
		long time = System.currentTimeMillis();
		System.out.println("Distributing training trees...");
		for (Tree t : originalTrainExamples.keySet())
			trainExamples.put(dt.dt(t), originalTrainExamples.get(t));
		System.out.println("Distributing testing trees...");
		for (Tree t : originalTestExamples.keySet())
			testExamples.put(dt.dt(t), originalTestExamples.get(t));
		System.out.println((System.currentTimeMillis()-time)/1000.0+" seconds");
		// Sequential integer ids, assigned in key-set iteration order.
		HashMap<double[], Integer> exampleIds = new HashMap<double[], Integer>(trainExamples.size());
		int nextId = 0;
		for (double[] vector : trainExamples.keySet()) {
			exampleIds.put(vector, nextId);
			nextId++;
		}
		System.out.println("Building parameters...");
		MultiClassificationSVM<String, double[]> svm = new MultiClassificationSVM<String, double[]>(new C_SVC<String, double[]>());
		MultiClassProblemImpl<String, double[]> artificialProblem = new MultiClassProblemImpl<String, double[]>(String.class, new StringLabelInverter(), trainExamples, exampleIds, new NoopScalingModel<double[]>());
		Builder<String, double[]> svmParamBuilder = new Builder<String, double[]>();
		svmParamBuilder.kernel = new KernelFunction<double[]>() {
			public double evaluate(double[] x, double[] y) {
				try {
					return ArrayMath.dot(x, y);
				} catch (Exception e) {
					// NOTE(review): a failed dot product is reported and then treated
					// as a kernel value of 0, which lets training continue on bad data.
					e.printStackTrace();
					return 0;
				}
			}
		};
		svmParamBuilder.eps = (float) 0.001;
		svmParamBuilder.cache_size = 2048;
		svmParamBuilder.allVsAllMode = AllVsAllMode.None;
		svmParamBuilder.oneVsAllMode = OneVsAllMode.Best;
		svmParamBuilder.oneVsAllThreshold = 0.1;
		System.out.println("Training...");
		time = System.currentTimeMillis();
		MultiClassModel<String, double[]> model = svm.train(artificialProblem, svmParamBuilder.build());
		System.out.println((System.currentTimeMillis()-time)/1000.0+" seconds");
		System.out.println("Testing...");
		double correct = 0;
		for (double[] vector : testExamples.keySet()) {
			try {
				if (model.predictLabel(vector).equals(testExamples.get(vector)))
					correct++;
			}
			catch(Exception e) {
				// Any failure while predicting counts the example as misclassified.
				System.out.println("No good label found!");
			}
		}
		// Report a ratio, not the raw count, so the figure is comparable with
		// the one printed by runOriginal.
		System.out.println("Accuracy: "+(correct/testExamples.size()));
	}
	
	/**
	 * Reads labelled trees from a text file with one example per line, in the
	 * form {@code <label> |BT| <tree> |ET|}. The text before the begin marker is
	 * the label and the text between the markers is parsed with
	 * {@link Tree#fromPennTree}. Blank lines are skipped. Because the result is
	 * keyed by tree, a later line replaces an earlier one whose tree is an equal
	 * key.
	 * 
	 * @param file path of the examples file
	 * @return the parsed trees mapped to their labels
	 * @throws IllegalArgumentException if a non-blank line lacks the markers or
	 *                                  has them in the wrong order
	 * @throws Exception if the file cannot be read or a tree cannot be parsed
	 */
	private HashMap<Tree, String> loadExamples(String file) throws Exception {
		HashMap<Tree, String> examples = new HashMap<Tree, String>();
		BufferedReader in = new BufferedReader(new FileReader(file));
		try {
			int lineNumber = 0;
			String line;
			while ((line = in.readLine()) != null) {
				lineNumber++;
				if (line.trim().length() == 0)
					continue;
				int begin = line.indexOf(TREE_BEGIN);
				int end = line.indexOf(TREE_END);
				int treeStart = begin + TREE_BEGIN.length();
				if (begin < 0 || end < treeStart)
					throw new IllegalArgumentException("Malformed example at " + file + ":" + lineNumber
							+ " (expected '<label> " + TREE_BEGIN + " <tree> " + TREE_END + "')");
				String label = line.substring(0, begin).trim();
				Tree t = Tree.fromPennTree(line.substring(treeStart, end).trim());
				examples.put(t, label);
			}
		} finally {
			// Close the reader even when reading or parsing fails.
			in.close();
		}
		return examples;
	}
}
